function [theta_hat, P_hat] = Refit(A, K, labels)
%REFIT Re-estimate node parameters theta and block matrix P from fixed labels.
%   [theta_hat, P_hat] = Refit(A, K, labels)
%
%   Inputs
%     A      - n-by-n adjacency matrix. Its diagonal is ignored (zeroed).
%              NOTE(review): assumed binary 0/1; with other values
%              B = 1 - A can be negative and theta_hat may become complex.
%     K      - number of communities.
%     labels - length-n vector; nodes with labels(i) == k (k = 1..K) form
%              community k. Nodes whose label is not in 1..K keep
%              theta_hat = 0 and contribute nothing to P_hat.
%
%   Outputs
%     theta_hat - n-by-1 vector. For node i in community k,
%                   theta_hat(i)^2 = (A_k*B_k*A_k)(i,i) / (B_k*A_k*B_k)(i,i),
%                 where A_k, B_k are the community-k diagonal blocks of A and
%                 B = 1 - A (both with zero diagonal). A zero denominator
%                 yields Inf or NaN for that node.
%     P_hat     - K-by-K matrix with
%                   P_hat(k,l) = sum_{i in k, j in l} A(i,j)
%                              / sum_{i in k, j in l} B(i,j)*theta_i*theta_j
%                 for k ~= l, and P_hat(k,k) = 1.
%                 NOTE(review): the estimators look like they target a model
%                 whose edge odds are theta_i*theta_j*P(k,l) with P(k,k) = 1
%                 as an identifiability constraint -- confirm with the caller.
%
%   Side effect: prints the total number of zero entries in the
%   denominator used for P_hat.

n = size(A, 1);

% Self-loops are ignored: zero the diagonal of A, then build the
% complement B = 1 - A and zero its diagonal as well.
A = A - diag(diag(A));
B = 1 - A;
B = B - diag(diag(B));

Pi_hat = zeros(n, K);      % one-hot membership: Pi_hat(i,k) = 1 iff labels(i) == k
theta_hat = zeros(n, 1);
for k = 1:K
    in_k = (labels == k);  % evaluate the community mask once
    Pi_hat(in_k, k) = 1;

    A_k = A(in_k, in_k);
    B_k = B(in_k, in_k);
    AB_k = A_k * B_k;
    % Row sums of elementwise products give the needed diagonals without
    % forming the full triple products:
    %   numer(i) = sum_j (A_k*B_k)(i,j) * A_k(j,i) = (A_k*B_k*A_k)(i,i)
    %   denom(i) = sum_j B_k(i,j) * (A_k*B_k)(j,i) = (B_k*A_k*B_k)(i,i)
    numer = sum(AB_k .* A_k.', 2);
    denom = sum(B_k .* AB_k.', 2);
    theta_hat(in_k) = sqrt(numer ./ denom);
end

% Block sums over (community k) x (community l):
%   Numerator(k,l)   = sum of A(i,j)
%   Denominator(k,l) = sum of B(i,j) * theta_i * theta_j
weighted_B = B .* (theta_hat * theta_hat');
Numerator = Pi_hat.' * A * Pi_hat;
Denominator = Pi_hat.' * weighted_B * Pi_hat;

% Report one total; passing a vector to fprintf would repeat the line
% once per element.
fprintf('number of 0 entries in denominator in P: %d\n', nnz(Denominator == 0));

P_hat = Numerator ./ Denominator;
P_hat(1:K+1:end) = 1;      % diagonal fixed to 1

end